# coding:utf-8
import ast
import sqlite3


# Powered by SQLite
class DB:
    """Store block dictionaries in a single SQLite table (``dt_chain``), keyed by block hash.

    Each stored value is written as its ``repr`` text in ``block_value`` and read
    back with ``ast.literal_eval``. A few fields are copied out of the block dict
    into their own columns so they can be listed by ``show`` and filtered by
    ``find``. Every public operation opens its own short-lived connection to
    ``db_server_url`` and always closes it, even on error.
    """

    # Column names of dt_chain, in schema order; also the whitelist of keys find() accepts.
    _COLUMNS = ('rid', 'height', 'block_hash', 'tx_id', 'tx_type', 'data_id',
                'business_id', 'public_key_hash', 'source_node_address', 'block_value')

    _INSERT_SQL = "INSERT INTO dt_chain values(?,?,?,?,?,?,?,?,?,?)"

    def __init__(self, db_server_url):
        """Open (creating if absent) the database at *db_server_url* and ensure dt_chain exists."""
        self.db_server_url = db_server_url

        # sqlite3.connect creates the database file if it does not exist yet.
        self._db_connect = sqlite3.connect(db_server_url)
        try:
            # Create the raw-data table.
            self._db_connect.execute('''
                create table if not exists dt_chain(
                    rid  integer  primary key autoincrement,
                    height  integer,
                    block_hash  character(70),
                    tx_id   character(70),
                    tx_type  character(10),
                    data_id  character(70),
                    business_id  character(110),
                    public_key_hash  character(40),
                    source_node_address  character(140),
                    block_value  character(10000)
                );
                ''')
            self._db_connect.commit()
        finally:
            self._db_connect.close()

    def _execute(self, sql, params=(), commit=False):
        """Run one statement on a fresh connection and return all fetched rows.

        The connection is closed in every case, including when the statement raises.
        """
        conn = sqlite3.connect(self.db_server_url)
        try:
            rows = conn.execute(sql, params).fetchall()
            if commit:
                conn.commit()
            return rows
        finally:
            conn.close()

    @staticmethod
    def _row_values(key, value):
        """Build the INSERT parameter tuple that stores *value* under *key*.

        Metadata columns are read from ``value['block_header']`` and
        ``value['transaction']`` when *value* is a dict. Any other value (e.g. a
        plain hash string) is stored with default metadata instead of crashing.
        """
        dic = value if isinstance(value, dict) else {}
        hd_dic = dic.get('block_header') or {}
        tx_dic = dic.get('transaction') or {}
        tx_content = tx_dic.get('tx_content') or {}
        data = tx_content.get('data') or {}
        return (None,  # rid: let SQLite assign the autoincrement id
                hd_dic.get('height', -1),
                key,
                tx_dic.get('tx_id', ''),
                tx_dic.get('tx_type', ''),
                data.get('data_id', ''),
                tx_dic.get('business_id', ''),
                tx_content.get('public_key_hash', ''),
                tx_content.get('source_node_address', ''),
                # repr() equals str() for dicts/lists, and also round-trips strings via literal_eval.
                repr(value))

    def create(self, block_hash, dic):
        """Insert *dic* as a new row under *block_hash* and return *block_hash*.

        Always inserts; it does not check whether *block_hash* is already stored.
        """
        self._execute(self._INSERT_SQL, self._row_values(block_hash, dic), commit=True)
        return block_hash

    def get(self, block_hash):
        """Return the value stored under *block_hash*, or ``{}`` if absent (same as ``db[block_hash]``)."""
        return self.__getitem__(block_hash)

    def find(self, dic):
        """Return the decoded values of all rows whose columns equal every ``{column: value}`` in *dic*.

        Conditions are combined with AND. An empty *dic* matches every row.
        Column names are checked against the table schema before being put
        into the SQL text.

        Raises:
            sqlite3.OperationalError: if a key is not a dt_chain column.
        """
        clauses = []
        values = []
        for column, wanted in dic.items():
            # Column names cannot be bound as parameters, so whitelist them to prevent SQL injection.
            if not isinstance(column, str) or column.lower() not in self._COLUMNS:
                raise sqlite3.OperationalError('no such column: %s' % (column,))
            clauses.append(column + '=?')
            values.append(wanted)

        sql = 'SELECT block_value FROM dt_chain'
        if clauses:
            sql += ' WHERE ' + ' AND '.join(clauses)

        return [ast.literal_eval(row[0]) for row in self._execute(sql, tuple(values))]

    def update(self, update_list):
        """Overwrite the stored value of the row keyed ``'l'`` with ``update_list[0]``.

        NOTE(review): only the first element is used, the target key is
        hard-coded to ``'l'``, and nothing is inserted if no ``'l'`` row exists.
        Presumably ``'l'`` marks the latest block; confirm with callers.
        """
        dic = update_list[0]
        self._execute('UPDATE dt_chain SET block_value=? WHERE block_hash=?', (repr(dic), 'l'), commit=True)

    # @return: list[record1[], record2[], ...]
    def show(self):
        """Return every row's metadata columns as a list of tuples.

        Each tuple is (height, block_hash, tx_id, tx_type, data_id, business_id,
        public_key_hash, source_node_address); block_value is not included.
        """
        return self._execute('SELECT height, block_hash, tx_id, tx_type, data_id, business_id, public_key_hash, source_node_address FROM dt_chain')

    def clean_up(self):
        """Delete all rows from dt_chain (the table itself is kept)."""
        self._execute('delete from dt_chain', commit=True)

    def __getattr__(self, name):
        """Return None for any attribute not found through normal lookup.

        NOTE(review): this hides typos (``db.misspelled`` is None and ``hasattr``
        is always True). It is kept because callers may rely on it.
        """
        return

    def __contains__(self, name):
        """Return True if at least one row is stored under block hash *name*."""
        return bool(self._execute('SELECT 1 FROM dt_chain WHERE block_hash=? LIMIT 1', (name,)))

    def __getitem__(self, key):
        """Return the decoded value stored under *key*, or ``{}`` if there is none.

        If several rows share *key* (``create`` allows duplicates), the first row returned by SQLite is used.
        """
        rows = self._execute('SELECT block_value FROM dt_chain WHERE block_hash=?', (key,))
        if rows:
            return ast.literal_eval(rows[0][0])
        return {}

    def __setitem__(self, key, value):
        """Store *value* under *key*: update block_value if *key* exists, otherwise insert a new row.

        NOTE(review): on update only block_value changes. The metadata columns
        (height, tx_id, ...) keep the values from the original insert.
        """
        conn = sqlite3.connect(self.db_server_url)
        try:
            # Existence check and write share one connection so they commit together.
            exists = conn.execute('SELECT 1 FROM dt_chain WHERE block_hash=? LIMIT 1', (key,)).fetchone() is not None
            if exists:
                conn.execute('UPDATE dt_chain SET block_value=? WHERE block_hash=?', (repr(value), key))
            else:
                conn.execute(self._INSERT_SQL, self._row_values(key, value))
            conn.commit()
        finally:
            conn.close()
